Exploring the Google QuickDraw Dataset with SketchRNN (Part 3)
This is the third part in a series of notes on my exploration of the recently released Google QuickDraw dataset [1], using the concurrently released SketchRNN model.
The QuickDraw dataset is curated from the millions of drawings contributed by over 15 million people around the world who participated in the "Quick, Draw!" A.I. Experiment, in which they were given the challenge of drawing objects belonging to a particular class (such as "cat") in under 20 seconds.
SketchRNN is a very impressive generative model that was trained to produce vector drawings using this dataset. It was of particular interest to me because it cleverly assembles many of the latest tools and techniques developed in machine learning, such as Variational Autoencoders, HyperLSTMs (a HyperNetwork for LSTM), autoregressive models, Layer Normalization, Recurrent Dropout, and the Adam optimizer.
Again, I've discarded the markdown cells and code blocks that were intended to explain or demonstrate something, retaining only the code needed to run the experiments in this notebook. Everything up to the section Principal Component Analysis in the Latent Space was copied directly from previous notebooks. Feel free to skip right ahead to that section, as that is where the really interesting analysis happens; everything before it is mostly utility functions to facilitate visualization. Here are links to the first and second notes.
These notebooks were derived from the notebook included with the code release. I've made significant stylistic changes and some minor changes to ensure Python 3 forward compatibility [2].
[1] This is somewhat misleading, as we mainly explore the Aaron Koblin Sheep Market (aaron-sheep) dataset, a smaller, lightweight dataset provided with the sketch-rnn release, along with a notebook that demos various models already pre-trained on this dataset. It is a natural starting point for experimenting with sketch-rnn. Since the dataset schema is the same as that of the QuickDraw dataset, operations performed here on this dataset are done without loss of generality.
[2] Magenta currently supports only Python 2.
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import tensorflow as tf
from matplotlib.animation import FuncAnimation
from matplotlib.path import Path
from matplotlib import rc
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from itertools import product
from six.moves import map, zip
from magenta.models.sketch_rnn.sketch_rnn_train import \
(load_env,
load_checkpoint,
reset_graph,
download_pretrained_models,
PRETRAINED_MODELS_URL)
from magenta.models.sketch_rnn.model import Model, sample
from magenta.models.sketch_rnn.utils import (lerp,
slerp,
get_bounds,
to_big_strokes,
to_normal_strokes)
# For inline display of animations
# equivalent to rcParams['animation.html'] = 'html5'
rc('animation', html='html5')
# set numpy output to something sensible
np.set_printoptions(precision=8,
edgeitems=6,
linewidth=200,
suppress=True)
tf.logging.info("TensorFlow Version: {}".format(tf.__version__))
Getting the Pre-Trained Models and Data
DATA_DIR = ('http://github.com/hardmaru/sketch-rnn-datasets/'
'raw/master/aaron_sheep/')
MODELS_ROOT_DIR = '/tmp/sketch_rnn/models'
DATA_DIR
PRETRAINED_MODELS_URL
download_pretrained_models(
models_root_dir=MODELS_ROOT_DIR,
pretrained_models_url=PRETRAINED_MODELS_URL)
We look at the layer-normalized model trained on the aaron_sheep dataset for now.
MODEL_DIR = MODELS_ROOT_DIR + '/aaron_sheep/layer_norm'
(train_set,
valid_set,
test_set,
hps_model,
eval_hps_model,
sample_hps_model) = load_env(DATA_DIR, MODEL_DIR)
class SketchPath(Path):
def __init__(self, data, factor=.2, *args, **kwargs):
vertices = np.cumsum(data[::, :-1], axis=0) / factor
codes = np.roll(self.to_code(data[::,-1].astype(int)),
shift=1)
codes[0] = Path.MOVETO
super(SketchPath, self).__init__(vertices,
codes,
*args,
**kwargs)
@staticmethod
def to_code(cmd):
# if cmd == 0, the code is LINETO
# if cmd == 1, the code is MOVETO (which is LINETO - 1)
return Path.LINETO - cmd
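To make the conversion above concrete, here is a minimal, self-contained demonstration, using a made-up 4-point sketch, of how stroke-3 offsets and pen states map to matplotlib Path vertices and codes:

```python
import numpy as np
from matplotlib.path import Path

# A hypothetical 4-point sketch in stroke-3 format: each row is
# (dx, dy, pen), where pen == 1 means "lift the pen after this point".
data = np.array([[ 10,  5, 0],
                 [  5, -5, 0],
                 [-15,  0, 1],
                 [  3,  3, 0]])

# Absolute vertex positions are the cumulative sum of the offsets.
vertices = np.cumsum(data[:, :-1], axis=0)

# Path.LINETO - pen maps pen 0 -> LINETO (2) and pen 1 -> MOVETO (1).
# Rolling by one shifts each MOVETO onto the point *after* the pen
# lift, and the first code is always forced to MOVETO.
codes = np.roll(Path.LINETO - data[:, -1].astype(int), shift=1)
codes[0] = Path.MOVETO
```

Note how the MOVETO lands on the point following the pen lift, which is exactly what the np.roll accomplishes.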
def draw(sketch_data, factor=.2, pad=(10, 10), ax=None):
if ax is None:
ax = plt.gca()
x_pad, y_pad = pad
x_pad //= 2
y_pad //= 2
x_min, x_max, y_min, y_max = get_bounds(data=sketch_data,
factor=factor)
ax.set_xlim(x_min-x_pad, x_max+x_pad)
ax.set_ylim(y_max+y_pad, y_min-y_pad)
sketch = SketchPath(sketch_data)
patch = patches.PathPatch(sketch, facecolor='none')
ax.add_patch(patch)
Load pre-trained models
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# loads the weights from checkpoint into our model
load_checkpoint(sess=sess, checkpoint_path=MODEL_DIR)
def encode(input_strokes):
    # convert from stroke-3 to the padded stroke-5 format
    strokes = to_big_strokes(input_strokes).tolist()
    # prepend the start-of-sequence token
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    z = sess.run(eval_model.batch_z,
                 feed_dict={
                     eval_model.input_data: [strokes],
                     eval_model.sequence_lengths: seq_len})[0]
    return z
def decode(z_input=None, temperature=.1, factor=.2):
    z = None
    if z_input is not None:
        z = [z_input]
    # sample a sketch sequence, optionally conditioned on latent code z
    sample_strokes, m = sample(
        sess,
        sample_model,
        seq_len=eval_model.hps.max_seq_len,
        temperature=temperature, z=z)
    # convert from stroke-5 back to stroke-3 format
    return to_normal_strokes(sample_strokes)
Principal Component Analysis in the Latent Space
What do you call a baby eigensheep? A lamb, duh.
We encode all of the sketches in the test set into their learned 128-dimensional latent space representations.
Z = np.vstack(map(encode, test_set.strokes))
Z.shape
Then, we find the top two principal axes, which represent the directions of maximum variance of the data encoded in the latent space.
pca = PCA(n_components=2)
pca.fit(Z)
The two components each account for only about 2% of the variance:
pca.explained_variance_ratio_
Let's project the data from the 128-dimensional latent space to the lower 2-dimensional space spanned by the first 2 principal components:
Z_pca = pca.transform(Z)
Z_pca.shape
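For intuition, pca.transform amounts to nothing more than centering the data and projecting it onto the principal axes. A minimal sketch, with random data standing in for the latent codes Z:

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.RandomState(0)
Z_fake = rng.randn(500, 128)  # stand-in for the encoded latent vectors

pca = PCA(n_components=2)
Z_proj = pca.fit_transform(Z_fake)

# transform(Z) == (Z - mean) projected onto the two component axes
Z_manual = np.dot(Z_fake - pca.mean_, pca.components_.T)
assert np.allclose(Z_proj, Z_manual)
```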
fig, ax = plt.subplots()
ax.scatter(*Z_pca.T)
ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')
plt.show()
We'd like to visualize the original sketches at their corresponding points on this plot. Each point corresponds to the latent code of a sketch, reduced to 2 dimensions. However, the plot is too dense to fit sufficiently large sketches without overlapping them. Therefore, we restrict our attention to a smaller region that encompasses roughly 80% of the data points, discarding those outside the 5th and 95th percentiles in both axes. The blue shaded rectangle highlights our region of interest.
((pc1_min, pc2_min),
(pc1_max, pc2_max)) = np.percentile(Z_pca, q=[5, 95], axis=0)
roi_rect = patches.Rectangle(xy=(pc1_min, pc2_min),
width=pc1_max-pc1_min,
height=pc2_max-pc2_min, alpha=.4)
fig, ax = plt.subplots()
ax.scatter(*Z_pca.T)
ax.add_patch(roi_rect)
ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')
plt.show()
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_xlim(pc1_min, pc1_max)
ax.set_ylim(pc2_min, pc2_max)
for i, sketch in enumerate(test_set.strokes):
sketch_path = SketchPath(sketch, factor=7e+1)
sketch_path.vertices[::,1] *= -1
sketch_path.vertices += Z_pca[i]
patch = patches.PathPatch(sketch_path, facecolor='none')
ax.add_patch(patch)
ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')
plt.savefig("../../files/sketchrnn/aaron_sheep_pca.svg",
format="svg")
Remark: There is a far more clever way to produce this plot that involves matplotlib transformations and the Collections API, by instantiating PathCollection with the keyword argument offsets set to the array of projected locations (in this instance, Z_pca). However, I was unable to get this to work (maybe I was working in the wrong coordinate system?). In any case, if you find a slicker way to produce this plot, please let me know. I'd love to learn about your approach!
Linear Interpolation in PCA
We generate a grid of 100 (10 × 10) linearly spaced points in the subspace spanned by the first 2 principal axes, within the rectangular region of interest defined previously (starting at the 5th percentile and ending at the 95th). The grid points are shown in orange below, overlaid on the test data points.
pc1 = lerp(pc1_min, pc1_max, np.linspace(0, 1, 10))
pc2 = lerp(pc2_min, pc2_max, np.linspace(0, 1, 10))
pc1_mesh, pc2_mesh = np.meshgrid(pc1, pc2)
fig, ax = plt.subplots()
ax.set_xlim(pc1_min-.5, pc1_max+.5)
ax.set_ylim(pc2_min-.5, pc2_max+.5)
ax.scatter(*Z_pca.T)
ax.scatter(pc1_mesh, pc2_mesh)
ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')
plt.show()
Next, we apply the inverse PCA transform to map the 100 grid points in 2 dimensions back to points in the original 128-dimensional latent space.
np.dstack((pc1_mesh, pc2_mesh)).shape
z_grid = np.apply_along_axis(pca.inverse_transform,
arr=np.dstack((pc1_mesh, pc2_mesh)),
axis=2)
z_grid.shape
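As a sanity check on the shapes involved, the same apply_along_axis pattern can be exercised with random data in place of Z; inverse_transform is just the affine map back to the 128-dimensional space:

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.RandomState(1)
Z_fake = rng.randn(300, 128)          # stand-in for the latent codes
pca = PCA(n_components=2).fit(Z_fake)

grid = np.dstack(np.meshgrid(np.linspace(-1., 1., 10),
                             np.linspace(-1., 1., 10)))
assert grid.shape == (10, 10, 2)

# apply inverse_transform to each 2-D point on the grid
z_grid = np.apply_along_axis(pca.inverse_transform, axis=2, arr=grid)
assert z_grid.shape == (10, 10, 128)

# inverse_transform undoes the projection and centering
point = grid[3, 7]
manual = np.dot(point, pca.components_) + pca.mean_
assert np.allclose(pca.inverse_transform(point), manual)
```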
Then, we use these latent codes to reconstruct the corresponding sketches with our decoder and observe how the sketches transition between the extents of the rectangular region of interest. Of particular interest is how the sketch transitions from left to right, and from top to bottom, as these are the directions that account for the most variance in the latent representations. First, we run our decoder with a relatively low temperature setting $\tau=0.1$ to minimize randomness in our samples.
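In sketch-rnn, the temperature $\tau$ scales the mixture weights and variances of the decoder's output distribution. As a toy illustration of the general idea (this is not the actual sampler), here is how dividing softmax logits by a temperature sharpens or flattens a categorical distribution:

```python
import numpy as np

def softmax_with_temperature(logits, tau):
    # Dividing the logits by tau sharpens (tau < 1) or flattens
    # (tau > 1) the distribution before normalizing.
    z = np.asarray(logits, dtype=float) / tau
    z -= z.max()  # subtract the max for numerical stability
    p = np.exp(z)
    return p / p.sum()

logits = [2., 1., .5]
sharp = softmax_with_temperature(logits, tau=.1)  # nearly deterministic
flat = softmax_with_temperature(logits, tau=1.)   # more diffuse
```

At $\tau=0.1$, nearly all of the probability mass concentrates on the most likely outcome, which is why low-temperature samples look so consistent.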
fig, ax_arr = plt.subplots(nrows=10,
ncols=10,
figsize=(8, 8),
subplot_kw=dict(xticks=[],
yticks=[],
frame_on=False))
fig.tight_layout()
for i, ax_row in enumerate(ax_arr):
for j, ax in enumerate(ax_row):
draw(decode(z_grid[-i-1,j], temperature=.1), ax=ax)
ax.axis('off')
plt.show()
You can definitely observe some interesting transitions and patterns here. In the bottom-right corner, you can see a cluster of similar-looking sheep, appearing sheared and unostentatious. As you move up and to the left, you begin to see some sheep with crudely-drawn circular scribbles for fur. Along the middle rows is where you see the most realistic-looking sheep. In the top-left corner, the sketches begin to get a little uh... shall we say, abstract. Like it was drawn by someone who really longed to see a sheep.
Running it with temperature $\tau=0.6$ doesn't yield much further insight.
fig, ax_arr = plt.subplots(nrows=10,
ncols=10,
figsize=(8, 8),
subplot_kw=dict(xticks=[],
yticks=[],
frame_on=False))
fig.tight_layout()
for i, ax_row in enumerate(ax_arr):
for j, ax in enumerate(ax_row):
draw(decode(z_grid[-i-1,j], temperature=.6), ax=ax)
ax.axis('off')
plt.show()
Eigensheep Decomposition
The principal components of the learned latent representation of sheep sketches are often referred to as the eigensheep... by me. And now by you.
Using the decoder, we can visualize the sketch representation of the first 2 eigensheep. These eigensheep are the orthogonal weight vectors that transform the latent representations into the 2-dimensional subspace we've been working in. They are just 128-dimensional vectors which, by themselves, are largely inscrutable. By treating them as latent codes, and reconstructing them as sketches with the decoder, we might be able to distill some meaningful interpretation.
z_pc1, z_pc2 = pca.components_
pca.explained_variance_ratio_
pca.explained_variance_
We draw the reconstructed sketches of the eigensheep at increasing temperatures from $\tau=0.1$ to $\tau=1.0$, taking 5 samples at each temperature setting.
fig, ax_arr = plt.subplots(nrows=5,
ncols=10,
figsize=(8, 4),
subplot_kw=dict(xticks=[],
yticks=[],
frame_on=False))
fig.tight_layout()
for row_num, ax_row in enumerate(ax_arr):
for col_num, ax in enumerate(ax_row):
t = (col_num + 1) / 10.
draw(decode(z_pc1, temperature=t), ax=ax)
if row_num + 1 == len(ax_arr):
ax.set_xlabel(r'$\tau={}$'.format(t))
plt.show()
fig, ax_arr = plt.subplots(nrows=5,
ncols=10,
figsize=(8, 4),
subplot_kw=dict(xticks=[],
yticks=[],
frame_on=False))
fig.tight_layout()
for row_num, ax_row in enumerate(ax_arr):
for col_num, ax in enumerate(ax_row):
t = (col_num + 1) / 10.
draw(decode(z_pc2, temperature=t), ax=ax)
if row_num + 1 == len(ax_arr):
ax.set_xlabel(r'$\tau={}$'.format(t))
plt.show()
This reminds me a bit of eigenfaces in facial recognition datasets, where you get dark blurs that highlight the most prominent features of a human face -- the eyes, nose, and mouth. The results are usually terrifying images of unsettling faces with deep, dark, sunken eyes and creepy smiles.
The reconstructions of the eigensheep above at lower temperature settings reveal more interesting patterns. For the first eigensheep, at $\tau=0.1$, you mostly see a pattern of round, black heads with two long, round legs, and some kind of tail (or possibly a cape?). This axis seems to lie along the direction that consists of variations in the structure of the head and legs of the sheep, and it would make sense that this accounts for most of the variance in the sketches. A look at the variation in the grid of reconstructed sketches above, going from left to right, seems consistent with this conjecture.
Looking at the second eigensheep, again concentrating on the samples generated at lower temperatures, we see some roughly scribbled circles representing the body, with 3-4 loosely attached legs and a tiny round head. Unlike the first principal axis, this one seems to lie along the direction that consists of variations primarily in the structure of the body of the sheep. Again, studying the transitions from top to bottom in the grid above seems consistent with this.
t-SNE Visualization
Finally, we embed the latent codes into 2 dimensions with t-SNE, which preserves local neighborhood structure rather than directions of maximum variance.
tsne = TSNE(n_components=2, n_iter=5000)
Z_tsne = tsne.fit_transform(Z)
fig, ax = plt.subplots()
ax.scatter(*Z_tsne.T)
ax.set_xlabel('$c_1$')
ax.set_ylabel('$c_2$')
plt.show()
tsne.kl_divergence_
((c1_min, c2_min),
(c1_max, c2_max)) = np.percentile(Z_tsne, q=[5, 95], axis=0)
roi_rect = patches.Rectangle(xy=(c1_min, c2_min),
width=c1_max-c1_min,
height=c2_max-c2_min, alpha=.4)
fig, ax = plt.subplots()
ax.scatter(*Z_tsne.T)
ax.add_patch(roi_rect)
ax.set_xlabel('$c_1$')
ax.set_ylabel('$c_2$')
plt.show()
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_xlim(c1_min, c1_max)
ax.set_ylim(c2_min, c2_max)
for i, sketch in enumerate(test_set.strokes):
sketch_path = SketchPath(sketch, factor=2.)
sketch_path.vertices[::,1] *= -1
sketch_path.vertices += Z_tsne[i]
patch = patches.PathPatch(sketch_path, facecolor='none')
ax.add_patch(patch)
ax.axis('off')
plt.savefig("../../files/sketchrnn/aaron_sheep_tsne.svg",
format="svg")
You can download the SVG image here.

